/*
 * memory/virtual.c
 *
 * Copyright (C) 2019 Aleksandar Andrejevic <theflash@sdf.lonestar.org>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <memory.h>
#include <log.h>

/* Pages needed to hold all MEMORY_MAX_BLOCKS block descriptors of one space. */
#define SPACE_BLOCK_PAGES (PAGE_NUMBER(PAGE_ALIGN_UP(MEMORY_MAX_BLOCKS * sizeof(memory_block_t))))

/*
 * Pages needed for the descriptor allocation bitmap: one bit per descriptor,
 * rounded up to whole bytes so that a MEMORY_MAX_BLOCKS which is not a
 * multiple of 8 still gets a bit for every descriptor.
 */
#define SPACE_BITMAP_PAGES (PAGE_NUMBER(PAGE_ALIGN_UP((MEMORY_MAX_BLOCKS + 7) / 8)))

/* Statically allocated descriptor of the upper-half (kernel) address space. */
static address_space_t kernel_space;

/* Lower-half address space; NULL means "none", which the lookups below reject. */
address_space_t *memory_lower_space = NULL;
/* Upper-half address space; always refers to kernel_space. */
address_space_t *const memory_upper_space = &kernel_space;
/*
 * Descriptor pool and bitmap reserved in memory_init_virtual(); presumably
 * consumed by code outside this file that builds lower-half spaces - TODO confirm.
 */
memory_block_t *user_memory_blocks = NULL;
dword_t *user_memory_block_bitmap = NULL;

/*
 * Takes an unused descriptor from the space's fixed-size pool.
 *
 * Scans the allocation bitmap for the first clear bit, marks it used and
 * returns the matching descriptor, zero-filled. Returns NULL when all
 * MEMORY_MAX_BLOCKS descriptors are in use. The descriptor is not linked
 * into either tree; callers do that themselves.
 */
static memory_block_t *block_create(address_space_t *space)
{
    for (dword_t index = 0; index < MEMORY_MAX_BLOCKS; index++)
    {
        if (test_bit(space->block_bitmap, index)) continue;

        set_bit(space->block_bitmap, index);

        memory_block_t *fresh = &space->blocks[index];
        memset(fresh, 0, sizeof(*fresh));
        return fresh;
    }

    return NULL;
}

/*
 * Returns a descriptor to its space's pool by clearing its bitmap bit.
 * The descriptor's contents are left as they are; callers in this file
 * unlink it from both trees before calling this.
 */
static void block_free(address_space_t *space, memory_block_t *block)
{
    memory_block_t *const pool = space->blocks;
    clear_bit(space->block_bitmap, block - pool);
}

/*
 * Looks up the block descriptor covering a page of an address space.
 *
 * @space:   address space to search
 * @address: page number (not a byte address) to look up
 *
 * Returns the block whose page range [address, address + size) contains
 * @address, or NULL when no block covers it, so that callers such as
 * memory_free() can report an invalid address instead of crashing.
 *
 * NOTE(review): the range check below assumes avl_tree_lower_bound() yields
 * the node with the greatest key not exceeding @address - confirm against
 * the AVL implementation.
 */
static memory_block_t *block_find(address_space_t *space, uintptr_t address)
{
    avl_node_t *node = avl_tree_lower_bound(&space->by_addr_tree, &address);
    if (!node) return NULL;

    memory_block_t *block = CONTAINER_OF(node, memory_block_t, by_addr_node);

    /* Unsigned subtraction also rejects address < block->address (it wraps to a huge value). */
    if (address - block->address >= block->size) return NULL;

    return block;
}

/*
 * Best-fit search. Returns the smallest free block with at least @min_size
 * pages in the size-ordered subtree rooted at @root, or NULL if there is none.
 *
 * The size tree holds allocated blocks as well as free ones: memory_allocate()
 * only relabels the block it hands out and never removes it from the tree.
 * Every candidate must therefore be checked for MEMORY_FLAG_FREE.
 * A node smaller than @min_size has only smaller nodes on its left, so only
 * its right subtree is searched.
 */
static memory_block_t *find_smallest_free_block(avl_node_t *root, size_t min_size)
{
    if (!root) return NULL;

    memory_block_t *block = CONTAINER_OF(root, memory_block_t, by_size_node);
    if (block->size < min_size) return find_smallest_free_block(root->right, min_size);

    /* This node is big enough; it is a candidate only if it is actually free. */
    memory_block_t *best = (block->flags & MEMORY_FLAG_FREE) ? block : NULL;

    memory_block_t *left = find_smallest_free_block(root->left, min_size);
    memory_block_t *right = find_smallest_free_block(root->right, min_size);

    if (left && (!best || left->size < best->size)) best = left;
    if (right && (!best || right->size < best->size)) best = right;

    return best;
}

/*
 * Finds the block that contains page @address by walking the subtree rooted
 * at @root. That tree is ordered by size, not address, so any node may have
 * to be visited.
 *
 * Returns the block only if it is free and has at least @min_size pages from
 * @address to its end. Otherwise returns NULL, including when the containing
 * block is allocated or no block contains @address. The size tree also holds
 * allocated blocks, so the MEMORY_FLAG_FREE check is required.
 */
static memory_block_t *find_preferred_free_block(avl_node_t *root, uintptr_t address, size_t min_size)
{
    if (!root) return NULL;

    memory_block_t *block = CONTAINER_OF(root, memory_block_t, by_size_node);
    if (address >= block->address && (address - block->address) < block->size)
    {
        /* Blocks never overlap, so this is the only block that can contain @address. */
        if (!(block->flags & MEMORY_FLAG_FREE)) return NULL;
        return (block->size - (address - block->address) >= min_size) ? block : NULL;
    }

    memory_block_t *found = find_preferred_free_block(root->left, address, min_size);
    return found ? found : find_preferred_free_block(root->right, address, min_size);
}

/*
 * Three-way comparator for the block trees' keys, which are read as size_t.
 * Returns -1, 0 or 1 when *a is less than, equal to or greater than *b.
 */
static int compare(const void *a, const void *b)
{
    const size_t lhs = *(const size_t*)a;
    const size_t rhs = *(const size_t*)b;

    /* Each comparison yields 0 or 1, so the difference is exactly -1, 0 or 1. */
    return (lhs > rhs) - (lhs < rhs);
}

/*
 * Returns the block descriptor covering a byte address.
 *
 * Addresses that are negative as intptr_t (top bit set) are looked up in
 * memory_upper_space, all others in memory_lower_space. Returns NULL if the
 * selected space is NULL; otherwise returns whatever block_find() reports
 * for the address's page.
 */
memory_block_t *memory_get_block_for_address(void *address)
{
    address_space_t *space = memory_lower_space;
    if ((intptr_t)address < 0) space = memory_upper_space;

    return space ? block_find(space, PAGE_NUMBER((uintptr_t)address)) : NULL;
}

/*
 * Reserves a range of virtual pages in an address space.
 *
 * @space:          address space to allocate from; ERR_NOTFOUND if NULL
 * @address:        in: preferred start address, or NULL to let the allocator
 *                  choose; out (on success): page-aligned start of the range
 * @size:           length in bytes measured from *address, rounded out to
 *                  whole pages
 * @flags:          flags stored in the allocated block
 * @section:        section recorded in the allocated block (may be NULL)
 * @section_offset: offset into @section recorded in the allocated block
 *
 * With a preferred address, the range must fit in the free block that
 * contains that address. Without one, the smallest free block that is large
 * enough is used (best fit). Any unused head or tail of the chosen free
 * block is split off and stays free.
 *
 * Returns:
 *   ERR_SUCCESS   on success, and also for a zero @size (then *address is
 *                 left untouched);
 *   ERR_INVALID   if the byte range wraps around;
 *   ERR_NOMEMORY  if no suitable free block or spare descriptor exists.
 */
sysret_t memory_allocate(address_space_t *space,
                         void **address,
                         size_t size,
                         memory_flags_t flags,
                         memory_section_t *section,
                         page_num_t section_offset)
{
    if (!space) return ERR_NOTFOUND;

    /* Must be checked before the page math below, which would underflow for size 0. */
    if (!size) return ERR_SUCCESS;

    uintptr_t start = (uintptr_t)*address;
    if (start + size - 1 < start) return ERR_INVALID;

    uintptr_t preferred_address = PAGE_NUMBER(start);
    size = PAGE_NUMBER(start + size - 1) - preferred_address + 1;

    sysret_t ret = ERR_NOMEMORY;
    lock_acquire_smart(&space->lock);

    memory_block_t *block;
    if (start) block = find_preferred_free_block(space->by_size_tree.root, preferred_address, size);
    else block = find_smallest_free_block(space->by_size_tree.root, size);

    /* The size tree also holds allocated blocks; never carve up one that is in use. */
    if (!block || !(block->flags & MEMORY_FLAG_FREE)) goto cleanup;

    size_t offset = start ? preferred_address - block->address : 0;
    size_t available = block->size - offset;
    if (available < size) goto cleanup;

    /*
     * Reserve every descriptor the split needs before touching the trees, so
     * that running out of descriptors leaves the space unchanged. A head split
     * is only needed when the preferred address is past the block's start;
     * splitting at offset 0 would leave a zero-sized block sharing its
     * address key with the new one.
     */
    memory_block_t *middle = NULL;
    memory_block_t *tail = NULL;

    if (offset && !(middle = block_create(space))) goto cleanup;
    if (size < available && !(tail = block_create(space)))
    {
        if (middle) block_free(space, middle);
        goto cleanup;
    }

    if (middle)
    {
        /* The original block keeps the free head; the new one starts at the preferred page. */
        middle->flags = MEMORY_FLAG_FREE;
        middle->address = preferred_address;
        middle->size = available;

        avl_tree_change_key(&space->by_size_tree, &block->by_size_node, &offset);
        avl_tree_insert(&space->by_addr_tree, &middle->by_addr_node);
        avl_tree_insert(&space->by_size_tree, &middle->by_size_node);

        block = middle;
    }

    if (tail)
    {
        /* Whatever lies past the requested pages stays free in its own block. */
        tail->flags = MEMORY_FLAG_FREE;
        tail->address = block->address + size;
        tail->size = block->size - size;

        avl_tree_change_key(&space->by_size_tree, &block->by_size_node, &size);
        avl_tree_insert(&space->by_addr_tree, &tail->by_addr_node);
        avl_tree_insert(&space->by_size_tree, &tail->by_size_node);
    }

    block->flags = flags;
    block->section = section;
    block->section_offset = section_offset;
    *address = (void*)(block->address * PAGE_SIZE);
    ret = ERR_SUCCESS;

cleanup:
    lock_release(&space->lock);
    return ret;
}

/*
 * Releases the allocated block that covers @address in @space.
 *
 * The block is marked free and its section binding is cleared. It is then
 * merged with any free neighbours on either side, so adjacent free ranges
 * are coalesced into one block. This function does not unmap any pages.
 *
 * Returns:
 *   ERR_SUCCESS   on success;
 *   ERR_NOTFOUND  if @space is NULL;
 *   ERR_INVALID   if no block covers @address or that block is already free.
 */
sysret_t memory_free(address_space_t *space, void *address)
{
    if (!space) return ERR_NOTFOUND;

    sysret_t status = ERR_INVALID;
    lock_acquire_smart(&space->lock);

    memory_block_t *victim = block_find(space, PAGE_NUMBER((uintptr_t)address));
    if (victim && !(victim->flags & MEMORY_FLAG_FREE))
    {
        victim->flags = MEMORY_FLAG_FREE;
        victim->section = NULL;
        victim->section_offset = 0;

        avl_node_t *neighbor;

        /* Absorb free successors into this block. */
        while ((neighbor = avl_get_next_node(&victim->by_addr_node)) != NULL)
        {
            memory_block_t *after = CONTAINER_OF(neighbor, memory_block_t, by_addr_node);
            if (!(after->flags & MEMORY_FLAG_FREE)) break;

            size_t merged = victim->size + after->size;
            avl_tree_change_key(&space->by_size_tree, &victim->by_size_node, &merged);

            avl_tree_remove(&space->by_addr_tree, &after->by_addr_node);
            avl_tree_remove(&space->by_size_tree, &after->by_size_node);
            block_free(space, after);
        }

        /* Then let free predecessors absorb this block. */
        while ((neighbor = avl_get_previous_node(&victim->by_addr_node)) != NULL)
        {
            memory_block_t *before = CONTAINER_OF(neighbor, memory_block_t, by_addr_node);
            if (!(before->flags & MEMORY_FLAG_FREE)) break;

            size_t merged = before->size + victim->size;
            avl_tree_change_key(&space->by_size_tree, &before->by_size_node, &merged);

            avl_tree_remove(&space->by_addr_tree, &victim->by_addr_node);
            avl_tree_remove(&space->by_size_tree, &victim->by_size_node);
            block_free(space, victim);
            victim = before;
        }

        status = ERR_SUCCESS;
    }

    lock_release(&space->lock);
    return status;
}

/*
 * Reserves virtual space for @area in @space and maps the area there.
 *
 * @space:   memory_lower_space or memory_upper_space; a non-NULL *@address
 *           must lie in that space's half (asserted)
 * @address: in: preferred address or NULL; out: start of the reserved range
 * @area:    area to map; area->count pages are reserved for it
 * @flags:   mapping flags passed to memory_map_area() (the reserved block
 *           itself is recorded with flags 0)
 *
 * If mapping fails, the reservation is released and the mapping error is
 * returned.
 */
sysret_t memory_view_area(address_space_t *space, void **address, const area_t *area, memory_flags_t flags)
{
    ASSERT((space == memory_lower_space
            && (*address == NULL || (intptr_t)*address > 0))
           || (space == memory_upper_space
               && (*address == NULL || (intptr_t)*address < 0)));

    sysret_t status = memory_allocate(space, address, area->count * PAGE_SIZE, 0, NULL, 0);
    if (status == ERR_SUCCESS)
    {
        status = memory_map_area(memory_default_table, area, *address, flags);
        if (status != ERR_SUCCESS) memory_free(space, *address);
    }

    return status;
}

/*
 * Maps the physical pages behind a buffer at a fresh upper-half address.
 *
 * @virtual:       start of the buffer in the current page table
 * @pinned:        in: preferred destination, or NULL to let the allocator
 *                 choose; out (on success): address of the buffer's first
 *                 byte in the new mapping, with the same page offset as
 *                 @virtual
 * @size:          buffer length in bytes; must be non-zero
 * @lock_contents: if true, the new mapping is created without
 *                 MEMORY_FLAG_WRITABLE
 *
 * Returns:
 *   ERR_SUCCESS  on success;
 *   ERR_INVALID  for an empty buffer;
 *   ERR_BADPTR   if the range wraps around or one of its pages is not mapped;
 *   otherwise the error from memory_allocate() or memory_map_page().
 */
sysret_t memory_pin_buffer(const void *virtual, void **pinned, size_t size, bool_t lock_contents)
{
    if (!size) return ERR_INVALID;

    uintptr_t first_byte = (uintptr_t)virtual;
    uintptr_t last_byte = first_byte + size - 1;
    if (last_byte < first_byte) return ERR_BADPTR;

    /* Count every page the buffer touches, including a partially used first page. */
    size_t num_pages = PAGE_NUMBER(last_byte) - PAGE_NUMBER(first_byte) + 1;

    /*
     * memory_allocate() expects a byte length. Aligning the hint keeps it from
     * reserving an extra page when *pinned has a page offset.
     */
    void *address = (void*)PAGE_ALIGN((uintptr_t)*pinned);
    sysret_t ret = memory_allocate(memory_upper_space, &address, num_pages * PAGE_SIZE, 0, NULL, 0);
    if (ret != ERR_SUCCESS) return ret;

    uintptr_t source = PAGE_ALIGN(first_byte);
    uintptr_t destination = (uintptr_t)address;
    memory_flags_t flags = MEMORY_FLAG_ACCESSIBLE | MEMORY_FLAG_STICKY;
    if (!lock_contents) flags |= MEMORY_FLAG_WRITABLE;

    for (size_t i = 0; i < num_pages; i++)
    {
        page_t *page = memory_get_page_mapping(memory_default_table, (void*)(source + i * PAGE_SIZE));
        if (!page)
        {
            ret = ERR_BADPTR;
            break;
        }

        ret = memory_map_page(memory_default_table, page, (void*)(destination + i * PAGE_SIZE), flags);
        if (ret != ERR_SUCCESS) break;
    }

    if (ret == ERR_SUCCESS)
    {
        *pinned = (void*)(destination | PAGE_OFFSET(first_byte));
    }
    else
    {
        /* NOTE(review): pages already mapped by the loop are not unmapped here; only the range is released. */
        memory_free(memory_upper_space, address);
    }

    return ret;
}

/*
 * Memory-management system calls. None of them is implemented yet: each
 * ignores its arguments and fails with ERR_NOSYSCALL.
 */

sysret_t syscall_alloc_memory(handle_t process, void **address, size_t size, memory_flags_t flags)
{
    (void)process;
    (void)address;
    (void)size;
    (void)flags;
    return ERR_NOSYSCALL;
}

sysret_t syscall_free_memory(handle_t process, void *address)
{
    (void)process;
    (void)address;
    return ERR_NOSYSCALL;
}

sysret_t syscall_commit_memory(handle_t process, void *address, size_t size)
{
    (void)process;
    (void)address;
    (void)size;
    return ERR_NOSYSCALL;
}

sysret_t syscall_uncommit_memory(handle_t process, void *address, size_t size)
{
    (void)process;
    (void)address;
    (void)size;
    return ERR_NOSYSCALL;
}

sysret_t syscall_query_memory(handle_t process, void *address, memory_block_info_t *info)
{
    (void)process;
    (void)address;
    (void)info;
    return ERR_NOSYSCALL;
}

sysret_t syscall_protect_memory(handle_t process, void *address, size_t size, memory_flags_t flags)
{
    (void)process;
    (void)address;
    (void)size;
    (void)flags;
    return ERR_NOSYSCALL;
}

sysret_t syscall_read_memory(handle_t process, void *address, void *buffer, dword_t size)
{
    (void)process;
    (void)address;
    (void)buffer;
    (void)size;
    return ERR_NOSYSCALL;
}

sysret_t syscall_write_memory(handle_t process, void *address, void *buffer, dword_t size)
{
    (void)process;
    (void)address;
    (void)buffer;
    (void)size;
    return ERR_NOSYSCALL;
}

/*
 * Builds the upper-half (kernel) address space and switches to its page table.
 *
 * Steps:
 *   - Carves the descriptor pools and bitmaps out of the metadata region.
 *   - Seeds kernel_space with one free block running from the start of the
 *     upper half up to MEMORY_METADATA_TOP.
 *   - Reserves the metadata region inside that block.
 *   - Maps @kernel_area at the start of the upper half in a new page table
 *     and loads that table.
 *
 * @kernel_area: area mapped accessible/writable/executable/sticky at the start
 *               of the upper half - presumably the kernel image; TODO confirm
 */
void memory_init_virtual(const area_t *kernel_area)
{
    kernel_space.blocks = memory_request_metadata_space(SPACE_BLOCK_PAGES, PAGE_SIZE);
    kernel_space.block_bitmap = memory_request_metadata_space(SPACE_BITMAP_PAGES, PAGE_SIZE);

    /* Metadata from here up to MEMORY_METADATA_TOP is reserved below as sticky; later requests are not. */
    uintptr_t global_metadata = memory_metadata_base;

    user_memory_blocks = memory_request_metadata_space(SPACE_BLOCK_PAGES, PAGE_SIZE);
    user_memory_block_bitmap = memory_request_metadata_space(SPACE_BITMAP_PAGES, PAGE_SIZE);

    AVL_TREE_INIT(&kernel_space.by_addr_tree, memory_block_t, by_addr_node, address, compare);
    AVL_TREE_INIT(&kernel_space.by_size_tree, memory_block_t, by_size_node, size, compare);
    lock_init(&kernel_space.lock);

    /* Only the first page of the kernel descriptor pool and of its bitmap is backed here. */
    page_t *initial_block_page = memory_acquire_page(MIN_PHYS_ADDR_BITS, MAX_PHYS_ADDR_BITS, PAGE_SIZE);
    ASSERT(initial_block_page != NULL);

    page_t *initial_bitmap_page = memory_acquire_page(MIN_PHYS_ADDR_BITS, MAX_PHYS_ADDR_BITS, PAGE_SIZE);
    ASSERT(initial_bitmap_page != NULL);

    sysret_t ret = memory_map_page(memory_default_table,
                                   initial_block_page,
                                   kernel_space.blocks,
                                   MEMORY_FLAG_ACCESSIBLE | MEMORY_FLAG_WRITABLE | MEMORY_FLAG_STICKY);
    if (ret != ERR_SUCCESS) KERNEL_CRASH("Memory block mapping failed");

    ret = memory_map_page(memory_default_table,
                          initial_bitmap_page,
                          kernel_space.block_bitmap,
                          MEMORY_FLAG_ACCESSIBLE | MEMORY_FLAG_WRITABLE | MEMORY_FLAG_STICKY);
    if (ret != ERR_SUCCESS) KERNEL_CRASH("Memory block bitmap mapping failed");

    /*
     * block_create() treats set bits as used descriptors. Clear the freshly
     * mapped bitmap page instead of relying on the physical page being zeroed.
     * NOTE(review): later bitmap pages are not mapped here - confirm whoever
     * backs them also zero-fills them.
     */
    memset(kernel_space.block_bitmap, 0, PAGE_SIZE);

    memory_block_t *root_block = block_create(&kernel_space);
    ASSERT(root_block != NULL);
    root_block->address = PAGE_NUMBER((uintptr_t)INTPTR_MAX + 1);
    root_block->size = PAGE_NUMBER(MEMORY_METADATA_TOP) - root_block->address;
    root_block->flags = MEMORY_FLAG_FREE;
    avl_tree_insert(&kernel_space.by_addr_tree, &root_block->by_addr_node);
    avl_tree_insert(&kernel_space.by_size_tree, &root_block->by_size_node);

    if (!(kernel_space.root_page_table = memory_create_page_table()))
        KERNEL_CRASH("Cannot create kernel page directory");

    /* memory_allocate() takes byte lengths, so the page counts are scaled by PAGE_SIZE. */
    void *address = (void*)memory_metadata_base;
    ret = memory_allocate(memory_upper_space,
                          &address,
                          (PAGE_NUMBER(global_metadata) - PAGE_NUMBER(memory_metadata_base)) * PAGE_SIZE,
                          MEMORY_FLAG_ACCESSIBLE | MEMORY_FLAG_WRITABLE,
                          NULL,
                          0);
    ASSERT(ret == ERR_SUCCESS);

    address = (void*)global_metadata;
    ret = memory_allocate(memory_upper_space,
                          &address,
                          (PAGE_NUMBER(MEMORY_METADATA_TOP) - PAGE_NUMBER(global_metadata)) * PAGE_SIZE,
                          MEMORY_FLAG_ACCESSIBLE | MEMORY_FLAG_WRITABLE | MEMORY_FLAG_STICKY,
                          NULL,
                          0);
    ASSERT(ret == ERR_SUCCESS);

    ret = memory_load_shadow_table(kernel_space.root_page_table);
    if (ret != ERR_SUCCESS) KERNEL_CRASH("Cannot mount the kernel space");

    address = (void*)((uintptr_t)INTPTR_MAX + 1);
    ret = memory_allocate(memory_upper_space,
                          &address,
                          kernel_area->count * PAGE_SIZE,
                          MEMORY_FLAG_ACCESSIBLE
                          | MEMORY_FLAG_WRITABLE
                          | MEMORY_FLAG_EXECUTABLE
                          | MEMORY_FLAG_STICKY,
                          NULL,
                          0);
    ASSERT(ret == ERR_SUCCESS);

    ret = memory_map_area(memory_shadow_table,
                          kernel_area,
                          address,
                          MEMORY_FLAG_ACCESSIBLE
                          | MEMORY_FLAG_WRITABLE
                          | MEMORY_FLAG_EXECUTABLE
                          | MEMORY_FLAG_STICKY);
    ASSERT(ret == ERR_SUCCESS);

    memory_unload_shadow_table();
    memory_load_default_table(kernel_space.root_page_table);
}
